{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "015d1b5d-29b9-416e-91d2-12ce08cc12df",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow_datasets.\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jraph\n",
    "import numpy as np\n",
    "import time\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from pprint import pprint\n",
    "\n",
    "from mlff.mdx.potential.mlff_potential_sparse import load_model_from_workdir\n",
    "from mlff.utils import jraph_utils\n",
    "from mlff.data import AseDataLoaderSparse\n",
    "from mlff.data import NpzDataLoaderSparse\n",
    "\n",
    "from so3lr import So3lr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6acdfb56-6536-4de3-9fad-dfbad9e887e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "MLFF:root:Load data from data/ethanol.npz.\n",
      "MLFF:root:Calculate short range neighbors within cutoff=4.5 Ang.\n",
      "MLFF:root:Calculate long-range neighbors within cutoff_lr=inf Ang.\n",
      "0it [00:00, ?it/s]/Users/adil.kabylda/miniconda3/envs/so3lr_env/lib/python3.12/site-packages/ase/neighborlist.py:279: RuntimeWarning: invalid value encountered in cast\n",
      "  np.ceil(bin_size * nbins_c / face_dist_c).astype(int)\n",
      "2000it [00:00, 56390.59it/s]\n",
      "MLFF:root:... done!\n"
     ]
    }
   ],
   "source": [
    "dataloader = NpzDataLoaderSparse(\n",
    "    'data/ethanol.npz',\n",
    ")\n",
    "\n",
    "data, stats = dataloader.load(\n",
    "    cutoff=4.5,\n",
    "    cutoff_lr=np.inf,\n",
    "    calculate_neighbors_lr=True,\n",
    "    pick_idx=np.arange(50) # take the first 50 frames from the data set\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "605da817-0976-4987-b44a-244f2922da20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GraphsTuple(nodes={'positions': array([[-0.3712776 , -0.45252128, -0.19895536],\n",
      "       [-0.72997075,  0.68413099,  0.79622273],\n",
      "       [ 0.98474324, -0.24509398, -0.47843776],\n",
      "       [-0.59848449, -1.4733098 ,  0.19563715],\n",
      "       [-1.00308552, -0.30490059, -1.10558217],\n",
      "       [-0.52242124,  1.73814903,  0.17984159],\n",
      "       [-1.6873976 ,  0.53773135,  1.23071945],\n",
      "       [ 0.08736885,  0.90074002,  1.4408355 ],\n",
      "       [ 1.214412  , -0.26932519, -1.46412348]]), 'atomic_numbers': array([6, 6, 8, 1, 1, 1, 1, 1, 1]), 'forces': array([[-26.65845641,  23.24921427,   2.1633537 ],\n",
      "       [ 33.23849668,  45.80338958, -77.83582996],\n",
      "       [ 43.15698382,  -6.54653358, -44.62892596],\n",
      "       [  6.9142078 ,   7.25861413,   0.43256605],\n",
      "       [  3.35740248,  -5.6187004 ,   5.12580848],\n",
      "       [-34.62580734, -52.22708068,  23.59973159],\n",
      "       [-23.99129633,   9.05796712,  15.11034046],\n",
      "       [ 15.82917769, -17.15955234,  35.86599296],\n",
      "       [-17.10727453,  -3.75687692,  40.18842496]]), 'hirshfeld_ratios': array([nan, nan, nan, nan, nan, nan, nan, nan, nan])}, edges={}, receivers=array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,\n",
      "       2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5,\n",
      "       5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8,\n",
      "       8, 8, 8, 8, 8, 8]), senders=array([1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 2, 3, 0, 4, 6, 8, 7, 5, 1, 3,\n",
      "       0, 4, 0, 1, 2, 4, 5, 6, 7, 8, 6, 8, 7, 5, 3, 2, 1, 0, 0, 1, 2, 3,\n",
      "       4, 6, 7, 8, 8, 5, 4, 7, 2, 1, 0, 3, 8, 6, 4, 5, 2, 1, 0, 3, 6, 0,\n",
      "       1, 2, 3, 4, 5, 7]), globals={'energy': array([-97192.67788305]), 'stress': array([[nan, nan, nan, nan, nan, nan]]), 'dipole_vec': array([[nan, nan, nan]]), 'total_charge': array([0], dtype=int16), 'num_unpaired_electrons': array([0], dtype=int16)}, n_node=array([9]), n_edge=array([72]), n_pairs=array([72]), idx_i_lr=array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,\n",
      "       2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5,\n",
      "       5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8,\n",
      "       8, 8, 8, 8, 8, 8]), idx_j_lr=array([1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 2, 3, 0, 4, 6, 8, 7, 5, 1, 3,\n",
      "       0, 4, 0, 1, 2, 4, 5, 6, 7, 8, 6, 8, 7, 5, 3, 2, 1, 0, 0, 1, 2, 3,\n",
      "       4, 6, 7, 8, 8, 5, 4, 7, 2, 1, 0, 3, 8, 6, 4, 5, 2, 1, 0, 3, 6, 0,\n",
      "       1, 2, 3, 4, 5, 7]))\n"
     ]
    }
   ],
   "source": [
    "# The data loader takes the input data file and converts each structure into a `jraph.GraphsTuple`.\n",
    "\n",
    "pprint(data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2ac875ba-fda4-4298-9c94-c3873e923360",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time without jit = 33.34\n"
     ]
    }
   ],
   "source": [
    "# Create a fresh iterator.\n",
    "\n",
    "batch_size = 1\n",
    "\n",
    "# Batch the graphs.\n",
    "batched_graphs = jraph.dynamically_batch(\n",
    "    data,\n",
    "    n_node=stats['max_num_of_nodes'] * batch_size + 1,\n",
    "    n_edge=stats['max_num_of_edges'] * batch_size + 1,\n",
    "    n_graph=batch_size + 1,\n",
    "    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1\n",
    ")\n",
    "\n",
    "\n",
    "so3lr_calc = So3lr(\n",
    "    calculate_forces=True,\n",
    "    lr_cutoff=100\n",
    ")\n",
    "\n",
    "start = time.time()\n",
    "for graph_batch in batched_graphs:\n",
    "    # Transform the batched graph to inputs dict.\n",
    "    inputs = jraph_utils.graph_to_batch_fn(\n",
    "        graph_batch\n",
    "    )\n",
    "\n",
    "    output = so3lr_calc(inputs)\n",
    "\n",
    "end = time.time()\n",
    "\n",
    "print(f'Time without jit = {(end - start):.4g}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a0c6733e-d6cd-4f96-8cc9-7fba82e43181",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of iterations =  50\n",
      "Time for jit compile = 1.981\n",
      "Time for jit evaluation = 0.11\n",
      "Total time = 2.092\n"
     ]
    }
   ],
   "source": [
    "# Create a fresh iterator.\n",
    "\n",
    "batch_size = 1\n",
    "\n",
    "# Batch the graphs. We discuss below why this is neccessary when running with jax.jit.\n",
    "batched_graphs = jraph.dynamically_batch(\n",
    "    data,\n",
    "    n_node=stats['max_num_of_nodes'] * batch_size + 1,\n",
    "    n_edge=stats['max_num_of_edges'] * batch_size + 1,\n",
    "    n_graph=batch_size + 1,\n",
    "    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1\n",
    ")\n",
    "\n",
    "so3lr_calc_jit = jax.jit(so3lr_calc)\n",
    "\n",
    "i = 0\n",
    "start = time.time()\n",
    "\n",
    "predicted = []\n",
    "\n",
    "for graph_batch in batched_graphs:\n",
    "    # Transform the batched graph to inputs dict.\n",
    "    inputs = jraph_utils.graph_to_batch_fn(\n",
    "        graph_batch\n",
    "    )\n",
    "    \n",
    "    if i == 0:\n",
    "        start_compile = time.time()\n",
    "    \n",
    "    # Waits for computation to finish for accurate time measurement\n",
    "    output = jax.block_until_ready(so3lr_calc_jit(inputs))\n",
    "\n",
    "    if i == 0:\n",
    "        end_compile = time.time()\n",
    "    \n",
    "    graph_batch.nodes['forces_true'] = output['forces']\n",
    "    graph_batch.nodes['energy_true'] = output['energy']\n",
    "    graph_batch.nodes['dipole_vec_true'] = output['dipole_vec']\n",
    "    graph_batch.nodes['hirshfeld_ratios_true'] = output['hirshfeld_ratios']\n",
    "\n",
    "    # predicted.append(jraph.unbatch(graph_batch))\n",
    "    \n",
    "    i += 1\n",
    "    \n",
    "end = time.time()\n",
    "\n",
    "print('Total number of iterations = ', i)\n",
    "print(f'Time for jit compile = {(end_compile - start_compile):.4g}')\n",
    "print(f'Time for jit evaluation = {(end - end_compile):.4g}')\n",
    "print(f'Total time = {(end - start):.4g}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b36a47d9-bfbb-48dc-ba8d-0691033c1c30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'dipole_vec': Array([[-0.15381902, -0.28700384, -0.02246571],\n",
      "       [ 0.        ,  0.        ,  0.        ]], dtype=float32),\n",
      " 'energy': Array([-14.779261,   0.      ], dtype=float32),\n",
      " 'forces': Array([[-3.2863107 , -2.3415163 ,  2.0469337 ],\n",
      "       [ 2.5173085 ,  1.5764256 , -0.6074784 ],\n",
      "       [-0.3820707 , -1.2290765 , -0.46262783],\n",
      "       [ 0.10792313,  1.1642374 ,  0.9063448 ],\n",
      "       [ 0.78319776,  1.6404856 , -0.9547439 ],\n",
      "       [ 0.11693724,  0.68606925,  1.1490139 ],\n",
      "       [ 0.65898407, -0.9101628 ,  0.16201867],\n",
      "       [-0.7330659 , -1.1155365 ,  0.01966219],\n",
      "       [ 0.21709657,  0.5290742 , -2.2591233 ],\n",
      "       [ 0.        ,  0.        ,  0.        ]], dtype=float32),\n",
      " 'hirshfeld_ratios': Array([0.7946877 , 0.76813865, 0.9233552 , 0.62973064, 0.58136225,\n",
      "       0.5984993 , 0.58792025, 0.5955206 , 0.59943235, 0.        ],      dtype=float32)}\n"
     ]
    }
   ],
   "source": [
    "# You can see the last entry has zeros, due to padding graph. See next section for explanation.\n",
    "pprint(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a523cfc6-a2fc-49f6-b162-b1b95c61a7ae",
   "metadata": {},
   "source": [
    "# Why jraph and batching hustle?\n",
    "\n",
    "You might wonder, why going through all the hustle with `jraph.dynamically_batch` and not just simply define an `inputs` dict for each structure `SO3LR` should be evaluated on. As illustrated above, `jax.jit` takes a long time during the first call, as it creates the computational graph and optimizes it via XLA afterwards. This is great, since we have seen that this can drastically speed up calculations even on CPU. There is a caveat though: `jax.jit` assumes static shapes of the input arrays. As an example, consider some function and its jitted counterpart\n",
    "\n",
    "```python\n",
    "fn = lambda A, B: do stuff ..\n",
    "fn_jit = jax.jit(fn)\n",
    "```\n",
    "\n",
    "Every time the shapes of `A` and `B` change, `jax.jit` will trigger a recompile, introducing a drastic computational overhead.\n",
    "\n",
    "\n",
    "When using `SO3LR` for structures of different size, this would for example correspond to\n",
    "\n",
    "```python\n",
    "inputs1['positions'].shape = (N1, 3)\n",
    "inputs2['positions'].shape = (N2, 3)\n",
    "inputs2['positions'].shape = (N2, 3)\n",
    "```\n",
    "\n",
    "where `N1`, `N2` and `N3` are the number of atoms of the first and second structure. Therefore, when calling `so3lr_jit` for each input there is triggered a recompile. To avoid such issues one can pad all arrays in `inputs` that depend on the input structure to some pre-defined value. Given some `iterator` over `jraph.GraphTuples` this is exactly what `jraph.dynamically_batch` is doing via `n_node`, `n_edge`, `n_graph` and `n_pair`. Afterwards, we transform the batched graph to the input using `jraph_utils.graph_to_inputs`. Note, that there is always one padding graph, which is used during the computation to dump values that depend to padded nodes, edges, and so on. For example, the code below will pad `batch_size = 2` structures per batch. Using the `stats` collected by the data loader, we can choose the remaining values based on the `batch_size`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b24c1084-153e-40ae-85a1-f16d96e36e5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n",
      "<class 'numpy.ndarray'>\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GraphsTuple(nodes={'atomic_numbers': None, 'forces': None, 'hirshfeld_ratios': None, 'positions': None}, edges={}, receivers=None, senders=None, globals={'dipole_vec': None, 'energy': None, 'num_unpaired_electrons': None, 'stress': None, 'total_charge': None}, n_node=None, n_edge=None, n_pairs=None, idx_i_lr=None, idx_j_lr=None)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size = 2\n",
    "\n",
    "batched_graphs = jraph.dynamically_batch(\n",
    "    data,\n",
    "    n_node=stats['max_num_of_nodes'] * batch_size + 1,\n",
    "    n_edge=stats['max_num_of_edges'] * batch_size + 1,\n",
    "    n_graph=batch_size + 1,\n",
    "    n_pairs= stats['max_num_of_nodes'] * (stats['max_num_of_nodes'] - 1) * batch_size + 1\n",
    ")\n",
    "\n",
    "# The arrays are numpy arrays at this point so they are on CPU and not on GPU.\n",
    "jax.tree.map(lambda x: print(type(x)), next(batched_graphs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "19775021-9421-40b1-aa65-ee0d32521fc9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'batch_segments': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2], dtype=int32),\n",
      " 'graph_mask': Array([ True,  True, False], dtype=bool),\n",
      " 'node_mask': Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "       False], dtype=bool),\n",
      " 'num_of_non_padded_graphs': Array(2, dtype=int32)}\n"
     ]
    }
   ],
   "source": [
    "pprint(jraph_utils.batch_info_fn(next(batched_graphs)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "so3lr",
   "language": "python",
   "name": "so3lr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
